#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 25 21:35:12 2022

@author: matthewroberts
"""

import os
import sys
import numpy as np
import glob

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import matplotlib.gridspec as gridspec
import matplotlib.colors as colors

import cartopy
import cartopy.io.shapereader as shpreader
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import cartopy.crs as ccrs
import pyproj

import datetime as dt
import wrf
from netCDF4 import Dataset
import pandas as pd
from pyproj import Geod

import warnings
# Suppress every warning category for the remainder of the run.
warnings.filterwarnings('ignore')

# Wall-clock reference used for the elapsed-time report at the end of the script.
start_time = dt.datetime.now()
print('\nProgram Started: {0}'.format(start_time),
      '===================',
      sep='\n')

def planarWind(uVal,vVal,bearing):
    """Return the wind component lying along a given bearing.

    The wind direction is taken as ``arctan2(uVal, vVal)`` in degrees, so it
    shares the convention of ``bearing`` (angle measured from north) provided
    ``uVal`` is the eastward and ``vVal`` the northward component.

    Parameters
    ----------
    uVal, vVal : scalar or array-like
        Horizontal wind components (any mutually broadcastable shapes).
    bearing : scalar
        Bearing of the cross-section plane in degrees.

    Returns
    -------
    Wind speed multiplied by the cosine of the angle between the wind
    direction and ``bearing``; negative where the wind opposes ``bearing``.
    """
    # Direction of the wind vector on the same compass convention as bearing
    wind_dir = np.rad2deg(np.arctan2(uVal, vVal))
    speed = np.sqrt(uVal**2 + vVal**2)
    # Angle between the section bearing and the wind, in radians
    offset = np.deg2rad(bearing - wind_dir)
    return speed * np.cos(offset)

def relax_zone_remover(input, sr):
    """Strip the last ``sr`` rows and the last ``sr`` columns from an array.

    Used to drop the extra points at the high-index edges of the fire grid.

    Parameters
    ----------
    input : array-like with at least two dimensions
        Field to trim; axis 0 and axis 1 are shortened.
    sr : int
        Number of trailing rows/columns to remove.

    Returns
    -------
    A new array with ``sr`` fewer entries along axes 0 and 1.  When ``sr`` is
    zero or negative the input object itself is returned unchanged.

    Raises
    ------
    IndexError
        If ``sr`` exceeds the size of axis 0 or axis 1.
    """
    if sr <= 0:
        return input
    # Check only the axes that are trimmed; a 1-D input is left for np.delete
    # to reject on axis 1.
    if sr > min(np.shape(input)[:2]):
        raise IndexError('cannot remove {0} trailing points from an array of '
                         'shape {1}'.format(sr, np.shape(input)))
    # One deletion per axis instead of one full copy per removed row/column.
    trimmed = np.delete(input, np.s_[-sr:], 0)
    return np.delete(trimmed, np.s_[-sr:], 1)

#########################
# Inputs
#########################

# BEAR FIRE
# fire name: selects the data/plot sub-directories and the central x-section
firename = 'bear_fire'
# fuel run whose directory listing provides the output file names
fuelstr = '_FUELx1'
# shaded field: 'w' for updrafts, 'smk' for smoke concentration
plot_type = 'w'
# perimeter workaround: the fire area is blanked for files whose index in the
# sorted list is below this value
start_thresh = 2
# average across multiple x-sections
AvgXsect = True
# number of x-sections on top of the central one (even number);
# xsect_val + 1 sections are interpolated in total
xsect_val = 30
# prefix of the saved figure file names
keyword = 'ALL_x_sect_w'
# upper y-limit of the cross-section panels
maxY = 12500
# spacing of the x-axis (distance) ticks
tick_dist_interval = 10
# Wmax label placement: [x, y, horizontal alignment]
xyloc = [62, 11000, 'right']

# # CALDOR FIRE
# firename = 'caldor_fire'
# # fuel code
# fuelstr = '_FUELx1'
# # 'w'/'smk' for updrafts/smoke concentration for shaded plot
# plot_type = 'w' 
# # workaround for perimeters issue
# start_thresh = 3
# # average across multiple x-sections
# AvgXsect = True
# # how many x-sects? (even number)
# xsect_val = 60
# # for naming plots
# keyword = 'ALL_x_sect_w'
# # set ylim
# maxY = 12500
# # tick interval for x-distances
# tick_dist_interval = 5
# # info for where to plot wmax
# xyloc = [1,11000,'left']

mainpath = '/Users/matthewroberts/Documents/Projects/LEAPHI'
# where to save figures
savepath = mainpath+'/plots/'+firename+'/x_sect'

# reference fuel run: its directory listing defines the output files to process
filepath = mainpath+'/data/'+firename+'/'+fuelstr
filelist_og = sorted(glob.glob(filepath+'/wrfout_d03*'))

# no-fire control run (full paths are kept for this list)
filepath2 = mainpath+'/data/'+firename+'/_NOFIRE'
filelist2 = sorted(glob.glob(filepath2+'/wrfout_d03*'))

# keep bare file names only, so each one can be joined with any fuel directory
filelist = [os.path.basename(wrfout) for wrfout in filelist_og]

############# TESTING #############
# # Caldor test
# filelist = [filelist[10]]
# filelist2 = [filelist2[10]]

# # Bear test
# filelist = [filelist[-9]]
# filelist2 = [filelist2[-9]]

# from here on filepath is the data root of the fire; the fuel
# sub-directory is appended per run
filepath = mainpath+'/data/'+firename
# fuel runs to compare, one cross-section panel each
fuels = ['_FUELx1', '_FUELx4', '_FUELx8']

# FOR TESTING
# filelist = ['/Users/matthewroberts/Documents/Projects/LEAPHI/data/'+firename+'/_FUELx4/wrfout_d03_2020-09-09_04:00:00']
# filelist = ['/Users/matthewroberts/Desktop/wrfout_d03_2020-09-08_23:30:00']
# filelist = ['/Users/matthewroberts/Desktop/wrfout_d03_2021-08-17_19:00:00']
# filelist2 = ['/Users/matthewroberts/Documents/Projects/LEAPHI/data/'+firename+'/_NOFIRE/wrfout_d03_2020-09-09_04:00:00']
# filelist2 = ['/Users/matthewroberts/Documents/Projects/LEAPHI/data/'+firename+'/wrf_nofire/wrfout_d03_2021-08-17_19:00:00']

# # Verify local directories
# print('\nSetting up '+keyword+' plots...')
# print('Mainpath: '+mainpath)
# print('Filepath: '+filepath)
# print('Savepath: '+savepath)
# # Check if directories exist, if not make them
# if not os.path.exists(mainpath):
#     sys.exit('Directory: '+mainpath+' does not exist!')
# if not os.path.exists(filepath):
#     sys.exit('Directory: '+filepath+' does not exist!')
# if not os.path.exists(savepath):
#     os.makedirs(savepath)

#########################
# Main loop
#########################
# For every output time: read the no-fire control winds, build composite
# cross sections for each fuel run, then draw one figure holding a panel per
# fuel run plus an inset map of the section locations.

if not AvgXsect:
    # Everything after the interpolation step consumes the per-section lists
    # that only the averaging path builds, so fail up front with a clear
    # message instead of a NameError after the first file has been read.
    sys.exit('AvgXsect = False is not supported: the script only implements '
             'the multi-section averaging path.')

# Lat/lon end points (start, end) of the central cross section
if firename == 'bear_fire':
    central_start, central_end = (39.9, -120.92), (39.47, -121.45)
elif firename == 'caldor_fire':
    central_start, central_end = (38.77, -120.34), (38.55, -120.68) #38.83, -120.3
else:
    sys.exit('No central cross section defined for fire: '+firename)

# heights the vertical cross sections are interpolated to
interp_height = np.arange(0,18030,30) # m

# Grid-point offsets of the parallel sections relative to the central one (0):
# xsect_val/2 on either side, xsect_val+1 sections in total, every `interv`
# grid points. The two offsets run in opposite senses.
interv = 2
x_idx = np.arange(int(xsect_val/-2.),int(xsect_val/2)+1,1)*interv
y_idx = np.arange(int(xsect_val/-2.),int(xsect_val/2)+1,1)[::-1]*interv

# shared by the bearing and the along-section distance calculations
geodesic = pyproj.Geod(ellps='WGS84')

# panel-independent plot settings
w_lvls = np.arange(-21,21.5,2) # shading levels for vertical velocity
regrid_factX = 30 # horizontal thinning of wind vectors
regrid_factY = 15 # vertical thinning of wind vectors

if filelist:
    # make sure the figures can be written before any heavy processing starts
    os.makedirs(savepath, exist_ok=True)
    # Read the county shapefile once; the same feature is reused for every figure
    reader = shpreader.Reader('/Users/matthewroberts/Documents/Data/shapefiles/countyl010g_shp/countyl010g.shp')
    counties = list(reader.geometries())
    COUNTIES = cartopy.feature.ShapelyFeature(counties, ccrs.PlateCarree())

time_arr = []
for f, wrfout_name in enumerate(filelist):

    ################
    # NO FIRE control run: background flow subtracted from every fuel run.
    # NOTE(review): control files are paired with the fuel files purely by
    # list position -- confirm both directories hold the same output times.
    with Dataset(filelist2[f],mode='r') as wrf_file2:
        # Winds, destaggered along each component's staggered dimension
        u2 = wrf.destagger(wrf.getvar(wrf_file2, "U", timeidx=-1),-1)
        v2 = wrf.destagger(wrf.getvar(wrf_file2, "V", timeidx=-1),1)
        w2 = wrf.destagger(wrf.getvar(wrf_file2, "W", timeidx=-1),0)
    ################

    # one entry per fuel run, in the order of `fuels`
    plot_w_cross = []
    plot_smk_cross = []
    plot_qc_cross = []
    plot_qi_cross = []
    plot_comp_wind = []
    plot_lfn_cross = []
    plot_lfn_x = []
    plot_lfn_y = []

    for ff in fuels:

        # The file is only needed while variables are read and sections are
        # interpolated; the context manager closes it afterwards.
        with Dataset(filepath+'/'+ff+'/'+wrfout_name,mode='r') as wrf_file:

            # Create timestamp and time array (one entry per fuel run).
            # NOTE(review): the stamp is the first time in the file while the
            # fields below use timeidx=-1 -- identical only for single-time files.
            times = wrf.extract_times(wrf_file,wrf.ALL_TIMES)
            tstamp = str(pd.Timestamp(times[0]))
            time_arr.append(tstamp)
            print('\nOpening '+tstamp+'...', end="", flush=True)

            # Standard lat/lon grid
            terrain = wrf.getvar(wrf_file, "HGT", timeidx=-1)
            lat = wrf.getvar(wrf_file, "XLAT", timeidx=-1)
            lon = wrf.getvar(wrf_file, "XLONG", timeidx=-1)
            # Heights: geopotential divided by gravity to get height in meters
            phb = wrf.getvar(wrf_file, "PHB", timeidx=-1)/9.80665
            ph = wrf.getvar(wrf_file, "PH", timeidx=-1)/9.80665
            # drop the lowest level so the height array can be paired with the
            # destaggered fields
            z = (ph+phb)[1:,:,:]
            # Smoke and cloud water/ice
            smoke = wrf.getvar(wrf_file, "fire_smoke", timeidx=-1)
            qc = wrf.getvar(wrf_file, "QCLOUD", timeidx=-1)*1000. # kg/kg to g/kg
            qi = wrf.getvar(wrf_file, "QICE", timeidx=-1)*1000. # kg/kg to g/kg

            # Winds: destagger appropriate dimensions and subtract background flow
            u = wrf.destagger(wrf.getvar(wrf_file, "U", timeidx=-1),-1)-u2
            v = wrf.destagger(wrf.getvar(wrf_file, "V", timeidx=-1),1)-v2
            w = wrf.destagger(wrf.getvar(wrf_file, "W", timeidx=-1),0)-w2

            # Fire grid variables: level-set function and lat/lons of fire mesh
            lfn1 = wrf.getvar(wrf_file, 'LFN', timeidx=-1)
            yf1 = wrf.getvar(wrf_file, 'FXLAT', timeidx=-1)
            xf1 = wrf.getvar(wrf_file, 'FXLONG', timeidx=-1)

            # subgrid ratio (fire-grid points per staggered met-grid point)
            sr = int(wrf_file.dimensions['west_east_subgrid'].size)//int(wrf_file.dimensions['west_east_stag'].size)
            # removing the relaxation zones of the level-set function
            lfn1 = relax_zone_remover(lfn1, sr)
            xf1 = relax_zone_remover(xf1, sr)
            yf1 = relax_zone_remover(yf1, sr)
            # match lfn to met grid
            lfn2 = lfn1[::sr,::sr]
            # remove weird perimeters before fire starts
            if f < start_thresh:
                lfn2[:,:] = np.nan
                lfn1[:,:] = np.nan

            print('Interp x-sections...', end="", flush=True)

            # Grid indices of the central section end points (same for every
            # parallel section, so looked up once per file)
            start_pt = wrf.ll_to_xy(wrf_file, *central_start)
            end_pt = wrf.ll_to_xy(wrf_file, *central_end)

            _lon_cross = []
            _lat_cross = []
            _lfn_cross = []
            _ter_cross = []
            _w_cross = []
            _smk_cross = []
            _qc_cross = []
            _qi_cross = []
            _comp_wind = []
            for ll in range(len(x_idx)):

                # Shift the central end points by this section's offsets.
                # NOTE(review): ll_to_xy is documented to return (x, y) while
                # XLAT/XLONG are indexed [south_north, west_east]; the first
                # index below receives the x value -- confirm this is intended.
                i0 = start_pt.data[0]+x_idx[ll]
                j0 = start_pt.data[1]+y_idx[ll]
                i1 = end_pt.data[0]+x_idx[ll]
                j1 = end_pt.data[1]+y_idx[ll]

                # Define the cross section start and end points (NE to SW)
                start_point = wrf.CoordPair(lat=lat[i0,j0], lon=lon[i0,j0])
                end_point = wrf.CoordPair(lat=lat[i1,j1], lon=lon[i1,j1])

                # arguments common to every interpolation along this section
                sect_kw = dict(wrfin=wrf_file, start_point=start_point,
                               end_point=end_point, latlon=True)

                # Surface lines, flipped so NE is on the right and SW on the left
                lfn_cross = wrf.interpline(lfn2, meta=False, **sect_kw)[::-1]
                ter_cross = wrf.interpline(terrain, meta=False, **sect_kw)[::-1]
                lon_cross = wrf.interpline(lon, meta=False, **sect_kw)[::-1]
                lat_cross = wrf.interpline(lat, meta=False, **sect_kw)[::-1]

                # Vertical sections on interp_height, flipped the same way
                u_cross = wrf.vertcross(u, z, levels=interp_height,
                                        meta=True, **sect_kw)[:,::-1]
                v_cross = wrf.vertcross(v, z, levels=interp_height,
                                        meta=True, **sect_kw)[:,::-1]
                w_cross = wrf.vertcross(w, z, levels=interp_height,
                                        meta=True, **sect_kw)[:,::-1]
                smk_cross = wrf.vertcross(smoke, z, levels=interp_height,
                                          meta=True, **sect_kw)[:,::-1]
                qc_cross = wrf.vertcross(qc, z, levels=interp_height,
                                         meta=True, **sect_kw)[:,::-1]
                qi_cross = wrf.vertcross(qi, z, levels=interp_height,
                                         meta=True, **sect_kw)[:,::-1]

                # Bearing for in-plane winds: the back azimuth (end -> start)
                # matches the left-to-right direction of the flipped section
                fwd_azimuth,back_azimuth,distance = geodesic.inv(start_point.lon, start_point.lat,
                                                                 end_point.lon, end_point.lat)
                # Calculate in-plane wind
                component_wind = planarWind(u_cross,v_cross,back_azimuth)

                _lon_cross.append(lon_cross)
                _lat_cross.append(lat_cross)
                _lfn_cross.append(lfn_cross)
                _ter_cross.append(ter_cross)
                _w_cross.append(w_cross)
                _smk_cross.append(smk_cross)
                _qc_cross.append(qc_cross)
                _qi_cross.append(qi_cross)
                _comp_wind.append(component_wind)

                print(str(ll)+'.', end="", flush=True)

        # wrf cross section interpolates to nearest points so array lengths may
        # vary by 1 or 2 points. Trim every section to the shortest one so all
        # dimensions match, then stack into (section, ...) arrays.
        minlength = min(len(sect) for sect in _lon_cross)
        _lon_cross = np.stack([s[:minlength] for s in _lon_cross], axis=0)
        _lat_cross = np.stack([s[:minlength] for s in _lat_cross], axis=0)
        _lfn_cross = np.stack([s[:minlength] for s in _lfn_cross], axis=0)
        _ter_cross = np.stack([s[:minlength] for s in _ter_cross], axis=0)
        _w_cross = np.stack([s[:,:minlength] for s in _w_cross], axis=0)
        _smk_cross = np.stack([s[:,:minlength] for s in _smk_cross], axis=0)
        _qc_cross = np.stack([s[:,:minlength] for s in _qc_cross], axis=0)
        _qi_cross = np.stack([s[:,:minlength] for s in _qi_cross], axis=0)
        _comp_wind = np.stack([s[:,:minlength] for s in _comp_wind], axis=0)

        # Collapse the section axis: mean for position/terrain/in-plane wind,
        # max for updraft/smoke/cloud, min for the level-set function
        lon_cross = np.nanmean(_lon_cross,axis=0).squeeze()
        lat_cross = np.nanmean(_lat_cross,axis=0).squeeze()
        lfn_cross = np.nanmin(_lfn_cross,axis=0).squeeze()
        ter_cross = np.nanmean(_ter_cross,axis=0).squeeze()
        w_cross = np.nanmax(_w_cross,axis=0).squeeze()
        smk_cross = np.nanmax(_smk_cross,axis=0).squeeze()
        qc_cross = np.nanmax(_qc_cross,axis=0).squeeze()
        qi_cross = np.nanmax(_qi_cross,axis=0).squeeze()
        comp_wind = np.nanmean(_comp_wind,axis=0).squeeze()

        # Distance [km] of every point from the first (left-most) point of the
        # mean section; geodesic.inv returns the distance in meters
        x_dist_array = np.asarray([
            geodesic.inv(lon_cross[0], lat_cross[0], lon_pt, lat_pt)[2]/1000.
            for lon_pt, lat_pt in zip(lon_cross, lat_cross)])

        # Cross section and coords of fire area (negative level-set values)
        burning = lfn_cross < 0
        lfn_x = x_dist_array[burning]
        lfn_y = ter_cross[burning]
        lfn_cross = lfn_cross[burning]

        # x-sect data
        plot_w_cross.append(w_cross)
        plot_smk_cross.append(smk_cross)
        plot_qc_cross.append(qc_cross)
        plot_qi_cross.append(qi_cross)
        plot_comp_wind.append(comp_wind)
        plot_lfn_cross.append(lfn_cross)
        plot_lfn_x.append(lfn_x)
        plot_lfn_y.append(lfn_y)

    #%%
    #########################
    # Plotting
    #########################
    print('Plotting...', end="", flush=True)
    fig = plt.figure(1,figsize=(10,5))
    spec = gridspec.GridSpec(nrows=22, ncols=40, figure=fig)
    ax1 = fig.add_subplot(spec[:12,:12]) #rows,cols
    ax2 = fig.add_subplot(spec[:12,13:25]) #rows,cols
    ax3 = fig.add_subplot(spec[:12,26:38]) #rows,cols
    ax4 = fig.add_subplot(spec[14:20,28:34], projection=ccrs.PlateCarree()) #rows,cols
    # Colorbar
    ax5 = fig.add_subplot(spec[16,8:23])

    axvals = [ax1, ax2, ax3]
    for xx in range(len(plot_w_cross)):

        axVal = axvals[xx]

        # plot updrafts
        if (plot_type=='w'):
            p_wind = axVal.contourf(x_dist_array, interp_height, plot_w_cross[xx], cmap='seismic',
                                  extend='both', levels=w_lvls, alpha=.8, zorder=0)
            p_smk = axVal.contour(x_dist_array,interp_height,plot_smk_cross[xx],levels=[.0001], #.0001 ug/kg
                                linewidths=3,colors='k',alpha=.5,zorder=1)
            p_qc = axVal.contour(x_dist_array,interp_height,plot_qc_cross[xx],levels=[.00001], #.001 g/kg
                                linewidths=3,colors='navy',alpha=.9,zorder=15)
            p_qi = axVal.contour(x_dist_array,interp_height,plot_qi_cross[xx],levels=[.00001], #.001 g/kg
                                linewidths=3,colors='dodgerblue',alpha=.9,zorder=15)

        # plot smoke
        if (plot_type=='smk'):
            dbz = 10*np.log10(plot_smk_cross[xx])
            dbz = np.ma.masked_invalid(dbz)
            # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which newer
            # Matplotlib releases no longer provide
            cmap1 = plt.get_cmap('binary',lut=16)
            p_smk = axVal.pcolormesh(x_dist_array,interp_height,dbz,cmap=cmap1,
                                    vmin=-400, vmax=-10,
                                    alpha=.75)

            # Colorbar
            cbar = plt.colorbar(p_smk)

        # Plot wind vectors (in-plane wind, vertical velocity), thinned by the
        # regrid factors
        wf = axVal.quiver(x_dist_array[::regrid_factX], interp_height[::regrid_factY],
                       plot_comp_wind[xx][::regrid_factY,::regrid_factX],
                       plot_w_cross[xx][::regrid_factY,::regrid_factX],
                       color='k', scale=100, width=0.005, pivot='tail', alpha=0.9, zorder=0)

        if axVal is ax3:
            # The shared velocity colorbar only exists for the updraft plot;
            # in 'smk' mode there is no p_wind to attach it to.
            if (plot_type=='w'):
                cbar = fig.colorbar(p_wind, orientation='horizontal', cax=ax5)
                cbar.set_ticks((w_lvls+1)[::2])
                cbar.set_ticklabels((w_lvls+1)[::2].astype(int))
                cbar.set_label(r'Vertical Velocity [m $\mathregular{s^{-1}}$]', fontsize=10, fontweight='heavy')
            # quiver key
            scale_vel = 10
            labstr = str(scale_vel)+r' m $\mathregular{s^{-1}}$'
            qk = axVal.quiverkey(wf, -2.15, -.47, scale_vel, labstr, labelpos='E',
                                 coordinates='axes',fontproperties=dict(size=12))

        # plot max velocity
        axVal.text(xyloc[0],xyloc[1],'Wmax: '+str(int(np.nanmax(plot_w_cross[xx])))+r' m $\mathregular{s^{-1}}$', horizontalalignment=xyloc[2], fontsize=10, zorder=12)
        # plot terrain
        axVal.plot(x_dist_array,ter_cross,color='k',linewidth=4.,zorder=2)
        axVal.fill_between(x_dist_array,ter_cross,y2=0, color='sienna', zorder=1)
        # Plot fire area (black outline under an orange-red line)
        axVal.plot(plot_lfn_x[xx],plot_lfn_y[xx],color='k',linewidth=7.,zorder=3)
        axVal.plot(plot_lfn_x[xx],plot_lfn_y[xx],color='orangered',linewidth=4.,zorder=4)

        # Set the x-ticks at regular along-section distances
        x_ticks = np.arange(x_dist_array[0],int(x_dist_array[-1]),tick_dist_interval).astype(int)
        axVal.set_xticks(x_ticks)
        axVal.set_xticklabels(x_ticks, rotation=0, horizontalalignment='center')
        # Set x/y lims
        axVal.set_xlim(x_dist_array[0],x_dist_array[-1])
        axVal.set_ylim(0,maxY)
        # Remove labels from inner panels
        if (axVal is ax2) or (axVal is ax3):
            axVal.set_yticklabels([])

        # Set border and ticks around plot
        for axis in ['top', 'bottom', 'left', 'right']:
            axVal.spines[axis].set_linewidth(2)  # change width
            axVal.spines[axis].set_zorder(99)

        axVal.xaxis.set_tick_params(width=2,length=5,direction="in", labelsize=10, zorder=99, bottom=True, top=True)
        axVal.yaxis.set_tick_params(width=2,length=5,direction="in", labelsize=10, zorder=99, left=True, right=True)

        # Title and axis labels
        if axVal is ax1:
            axVal.set_ylabel("Height MSL [m]", fontweight='bold', fontsize=12)
        if axVal is ax2:
            plt.suptitle(tstamp+' UTC', fontsize=12, fontweight='bold', zorder=10)
            axVal.set_xlabel("X-Distance [km]", fontweight='bold', fontsize=12)

    ######################
    # Inset map (uses the grid, terrain and fire mesh of the last fuel run read)
    extent = [lon[0,0], lon[0,-1], lat[0,0], lat[-1,0]]
    ax4.set_extent(extent, crs=ccrs.PlateCarree())

    # terrain in inset
    terMap = ax4.contourf(lon,lat,terrain, cmap='terrain', extend='both',
                          levels=np.arange(-400,2400,100), zorder=0,
                          transform=ccrs.PlateCarree())

    # fire perimeter in inset; contour takes the plural `linewidths`
    perimPlot = ax4.contour(xf1, yf1, lfn1, levels=[0], colors='k', linewidths=3,
                            zorder=1, transform=ccrs.PlateCarree())
    # plot extent of x-sections (dashed lines on inset): first and last section
    for jjj in [0,-1]:
        Xsect = ax4.plot(_lon_cross[jjj,:],_lat_cross[jjj,:],linewidth=3,color='white',
                         alpha=.5, zorder=2, transform=ccrs.PlateCarree())
        Xsect = ax4.plot(_lon_cross[jjj,:],_lat_cross[jjj,:],linewidth=1,color='k',
                         linestyle='--', zorder=2, transform=ccrs.PlateCarree())

    ax4.set_xticks(np.linspace(extent[0]-.03,extent[1]+.03,5)[1:-1],crs=ccrs.PlateCarree()) # set longitude indicators
    ax4.set_yticks(np.linspace(extent[2]-.03,extent[3]+.03,5)[1:-1],crs=ccrs.PlateCarree()) # set latitude indicators
    lon_formatter = LongitudeFormatter(number_format='0.2f',degree_symbol=u'\N{DEGREE SIGN}',dateline_direction_label=True) # format lons
    lat_formatter = LatitudeFormatter(number_format='0.2f',degree_symbol=u'\N{DEGREE SIGN}') # format lats
    ax4.xaxis.set_major_formatter(lon_formatter) # set lons
    ax4.yaxis.set_major_formatter(lat_formatter) # set lats
    ax4.xaxis.set_tick_params(width=2,length=5,direction='in',labelsize=7, rotation=20, bottom=True, top=True, zorder=99)
    ax4.yaxis.set_tick_params(width=2,length=5,direction='in',labelsize=7, rotation=20, left=True, right=True, zorder=99)

    # county outlines (shapefile read once, before the loop)
    ax4.add_feature(COUNTIES, facecolor='none', edgecolor='k', linewidth=1, alpha=.4, zorder=10)

    # create timestamp for saving files
    savetime = dt.datetime.strptime(tstamp, "%Y-%m-%d %H:%M:%S")
    savetime = dt.datetime.strftime(savetime, "%Y%m%d_%H%M")
    # save
    plt.savefig(savepath+'/'+keyword+'_'+savetime+'.png',bbox_inches='tight',dpi=260)
    plt.close('all')
    print('Saved '+keyword+'_'+savetime+'.png')


# Report the total wall-clock run time
end_time = dt.datetime.now()
elapsed = end_time-start_time
print('===================',
      'Total elapsed time: {0}\n'.format(elapsed),
      sep='\n')